### #########################################
### Calculates leave-one-out error
### Uses the error-minimizing model (not 1SE)
### Output is a table
### #########################################

here::i_am("newR/06_loo.R")

library(here)
library(brms)
library(tidyverse)
library(abind)

source(here::here("newR", "00_helpers.R"))

### Analysis data set used for every refit below
dat <- readRDS(file = here::here("working", "tidy_model_data.rds"))

### Restore the saved variable-selection objects into the workspace
### (presumably where f_min, used further down, comes from -- TODO confirm)
load(file = here::here("working", "varsel_formulas.rdata"))

### Columns that together make up the compositional outcome
pvars <- c("Con_BE", "Lab_BE", "Lib_BE", "Nat_BE", "Oth2_BE")

### Non-positive entries are replaced by a small positive value
floor_value <- 1 / 40000
for (party_col in pvars) {
    shares <- dat[, party_col]
    dat[, party_col] <- replace(shares, shares <= 0, floor_value)
}

### Make sure everything sums to 1 (each row is divided by its own total)
row_totals <- rowSums(dat[, pvars])
dat[, pvars] <- dat[, pvars] / row_totals

### Create a matrix as dep. var.; column names and order follow pvars
dat$y <- cbind(Con_BE = dat$Con_BE,
               Lab_BE = dat$Lab_BE,
               Lib_BE = dat$Lib_BE,
               Nat_BE = dat$Nat_BE,
               Oth2_BE = dat$Oth2_BE)

### Priors shared by the models: one intercept prior per non-reference
### dpar, the precision parameter phi, and the Labour poll change
### coefficient (common to both models)
common_priors <- c(
    ## Intercepts
    prior("normal(0, 0.429)", class = "Intercept", dpar = "muLabBE"),
    prior("normal(-0.95, 0.584)", class = "Intercept", dpar = "muLibBE"),
    prior("normal(-2.457, 0.714)", class = "Intercept", dpar = "muNatBE"),
    prior("normal(-0.95, 0.584)", class = "Intercept", dpar = "muOth2BE"),
    ## Precision
    prior("normal(11.5, 3.25)", class = "phi"),
    ## Labour Poll Change
    prior("normal(0.241, 0.12)", coef = "LabPollChg_sc", dpar = "muLabBE"),
    prior("normal(0.06, 0.0905)", coef = "LabPollChg_sc", dpar = "muLibBE"),
    prior("normal(0.06, 0.0905)", coef = "LabPollChg_sc", dpar = "muNatBE"),
    prior("normal(0.06, 0.0905)", coef = "LabPollChg_sc", dpar = "muOth2BE")
)

### Now set up priors
### Start with the CV-minimizing model first
### NOTE(review): priors_min is built here but is not passed to update()
### further down in this file -- confirm whether the refits should use it.
priors_min <- common_priors

### Rows of the prior table for class "b" terms with a named coefficient.
### get_prior() is queried once and reused for both the coefficient names
### and the distributional parameters.
b_terms <- get_prior(f_min,
                     family = dirichlet(),
                     data = dat) %>%
    filter(class == "b") %>%
    filter(coef != "")

### LabPollChg_sc already has its own priors in common_priors
coefs <- setdiff(b_terms %>% pull(coef),
                 "LabPollChg_sc")

dpars <- b_terms %>%
    pull(dpar) %>%
    unique()

for (k in coefs) {
    for (p in dpars) {
        ## Three-letter party stem of the dpar,
        ## e.g. "muLabBE" -> "Lab", "muOth2BE" -> "Oth"
        short_party <- sub("mu(...).?BE", "\\1", p)
        ## If it's a candidacy term and it concerns the party we're
        ## modelling, use the wider prior. Both tests are scalar, so the
        ## short-circuiting && is the right operator inside if ().
        if (grepl("Cand_BE", k) && grepl(short_party, k)) {
            prior_string <- "normal(0, 1)"
        } else {
            prior_string <- "normal(0, 0.25)"
        }

        priors_min <- c(priors_min,
                        set_prior(prior_string, class = "b",
                                  coef = k,
                                  dpar = p))
    }
}


### Previously fitted model; it is refitted once per left-out row below
mod <- readRDS(here::here("working",
                          "model_min.rds"))

### Cache for the leave-one-out predictive draws
loo_file <- here::here("working",
                       "loo_errs_min.rds")

### Sampler settings for each refit. The number of retained draws
### (chains * post-warmup iterations) sizes the first dimension of holder.
loo_chains <- 5
loo_iter <- 1300
loo_warmup <- 1000
loo_draws <- loo_chains * (loo_iter - loo_warmup)

if (file.exists(loo_file)) {
    ### Reuse cached draws instead of refitting
    holder <- readRDS(loo_file)

    ### A cache built from a different data set would be silently
    ### misaligned with dat in the steps that follow, so fail early.
    if (!identical(as.integer(dim(holder)[-1]),
                   c(nrow(dat), ncol(dat$y)))) {
        stop("Cached LOO draws in ", loo_file,
             " do not match the dimensions of the current data; ",
             "delete the file to recompute.",
             call. = FALSE)
    }
} else {
    ### draws x left-out observation x outcome category
    holder <- array(data = NA,
                    dim = c(loo_draws, nrow(dat), ncol(dat$y)))

    for (i in seq_len(nrow(dat))) {
        ### Refit without row i ...
        fm2 <- update(mod,
                      newdata = dat[-i,],
                      chains = loo_chains,
                      cores = 5,
                      iter = loo_iter,
                      warmup = loo_warmup,
                      seed = 43145,
                      control = list(adapt_delta =  0.95,
                                     max_treedepth = 11))

        ### ... and predict the row that was left out
        pp <- posterior_predict(fm2,
                                newdata = dat[i,],
                                summary = FALSE)
        ### Only one row was predicted, so take the first slice of the
        ### middle dimension
        holder[,i,] <- pp[,1,]

    }

    saveRDS(holder,
            file = loo_file)

}
        

### ####################################################
### Zero stuff out
### ####################################################

### Candidacy indicators, listed in the same party order as the columns of
### dat$y; the (row, column) positions found here are reused on y and on
### every slice of holder.
candvars <- paste0(c("Con", "Lab", "Lib", "Nat", "Oth"),
                  "Cand_BE")
candmat <- as.matrix(dat[,candvars])

### Positions where the candidacy indicator is 0; computed once rather
### than on every pass of the loop below
no_cand <- which(candmat == 0, arr.ind = TRUE)

### Replace adjusted outcome values (y0 keeps the unmodified copy)
### NOTE(review): unlike the predictions below, y is not re-normalised
### after zeroing, so affected rows sum to less than 1 -- confirm intended.
y <- y0 <- dat$y
y[no_cand] <- 0

### Unmodified copy of the predictive draws
holder0 <- holder

### Zero out fitted values, then rescale each row to sum to 1
for (i in seq_len(dim(holder)[1])) {
    x <- holder[i,,]
    x[no_cand] <- 0
    x <- x / rowSums(x)
    holder[i,,] <- x
}

### Both names refer to the same adjusted draws
f <- holder
pp <- holder

###########################################################################
### Summarise leave-one-out performance
### (helper functions are not defined in this file; presumably they come
### from newR/00_helpers.R)
###########################################################################

stat_labels <- c("Seats correctly predicted",
                 "Multiclass Brier score",
                 "Mean absolute error",
                 "Median absolute error",
                 "Predictions inside 95% interval")

### Every statistic is reported multiplied by 100
stat_values <- 100 * c(get_pcp3(f, y),
                       get_brier3(f, y),
                       get_mae3(f, y),
                       get_mae3(f, y, type = "median"),
                       get_calibration3(pp, y))

out <- data.frame(Statistic = stat_labels,
                  LOO = stat_values)

### Output is a table, stored for later use
saveRDS(out,
        file = here::here("working", "loo-perf.rds"))
